{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,\n",
    "                              TensorDataset)\n",
    "import argparse\n",
    "import logging\n",
    "import os\n",
    "import json\n",
    "import time\n",
    "import torch.nn.functional as F\n",
    "from preprocess import LCSTSProcessor\n",
    "from model import BertAbsSum\n",
    "from pytorch_pretrained_bert.tokenization import BertTokenizer\n",
    "from pytorch_pretrained_bert.modeling import BertModel\n",
    "from pytorch_pretrained_bert.optimization import BertAdam\n",
    "from preprocess import convert_examples_to_features\n",
    "from tqdm import tqdm, trange\n",
    "from transformer import Constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',\n",
    "                    datefmt='%m/%d/%Y %H:%M:%S',\n",
    "                    level=logging.INFO)\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "class ARGS(object):\n",
    "    data_dir = 'data/processed_data'\n",
    "    bert_model = 'pretrained_model'\n",
    "    output_dir = 'output'\n",
    "    learning_rate = 5e-5\n",
    "    num_train_epochs = 3\n",
    "    warmup_proportion = 0.1\n",
    "    max_src_len = 130\n",
    "    max_tgt_len = 30\n",
    "    train_batch_size = 1\n",
    "    decoder_config = None\n",
    "    print_every = 10\n",
    "    gradient_accumulation_steps = 1\n",
    "    GPU_index = '1'\n",
    "\n",
    "args = ARGS()\n",
    "\n",
    "def cal_loss(draft_logits, refine_logits, ground):\n",
    "    ground = ground[:, 1:]\n",
    "    draft_logits = draft_logits.view(-1, draft_logits.size(-1))\n",
    "    refine_logits = draft_logits.view(-1, refine_logits.size(-1))\n",
    "    ground = ground.contiguous().view(-1)\n",
    "    draft_loss = F.cross_entropy(draft_logits, ground, ignore_index=Constants.PAD)\n",
    "    refine_loss = F.cross_entropy(refine_logits, ground, ignore_index=Constants.PAD)\n",
    "    return draft_loss + refine_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "str expected, not int",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-12-b2db2516f06b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mGPU_index\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m'-1'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m     \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menviron\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"CUDA_VISIBLE_DEVICES\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;31m#args.GPU_index\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      3\u001b[0m \u001b[0mdevice\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"cuda\"\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_available\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m\"cpu\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0mn_gpu\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice_count\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_batch_size\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mn_gpu\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/sum/lib/python3.6/os.py\u001b[0m in \u001b[0;36m__setitem__\u001b[0;34m(self, key, value)\u001b[0m\n\u001b[1;32m    672\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__setitem__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    673\u001b[0m         \u001b[0mkey\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencodekey\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 674\u001b[0;31m         \u001b[0mvalue\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencodevalue\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    675\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mputenv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    676\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_data\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/sum/lib/python3.6/os.py\u001b[0m in \u001b[0;36mencode\u001b[0;34m(value)\u001b[0m\n\u001b[1;32m    742\u001b[0m         \u001b[0;32mdef\u001b[0m \u001b[0mencode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    743\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 744\u001b[0;31m                 \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"str expected, not %s\"\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    745\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mencoding\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'surrogateescape'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    746\u001b[0m         \u001b[0;32mdef\u001b[0m \u001b[0mdecode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mTypeError\u001b[0m: str expected, not int"
     ]
    }
   ],
   "source": [
    "if args.GPU_index != '-1':\n",
    "    os.environ[\"CUDA_VISIBLE_DEVICES\"] = args.GPU_index\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "n_gpu = torch.cuda.device_count()\n",
    "assert args.train_batch_size % n_gpu == 0\n",
    "logger.info(f'Using device:{torch.cuda.current_device()}, n_gpu:{n_gpu}')\n",
    "\n",
    "if not os.path.exists(args.output_dir):\n",
    "    os.makedirs(args.output_dir)\n",
    "model_path = os.path.join(args.output_dir, time.strftime('model_%m-%d-%H:%M:%S', time.localtime()))\n",
    "os.mkdir(model_path)\n",
    "logger.info(f'Saving model to {model_path}.')\n",
    "\n",
    "if args.decoder_config is not None:\n",
    "    with open(args.decoder_config, 'r') as f:\n",
    "        decoder_config = json.load(f)\n",
    "else:\n",
    "    with open(os.path.join(args.bert_model, 'bert_config.json'), 'r') as f:\n",
    "        bert_config = json.load(f)\n",
    "        decoder_config = {}\n",
    "        decoder_config['len_max_seq'] = args.max_tgt_len\n",
    "        decoder_config['d_word_vec'] = bert_config['hidden_size']\n",
    "        decoder_config['n_layers'] = 8\n",
    "        decoder_config['n_head'] = 12\n",
    "        decoder_config['d_k'] = 64\n",
    "        decoder_config['d_v'] = 64\n",
    "        decoder_config['d_inner'] = bert_config['hidden_size']\n",
    "        decoder_config['d_model'] = bert_config['hidden_size']\n",
    "        decoder_config['vocab_size'] = bert_config['vocab_size']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "# data preprocess\n",
    "processor = LCSTSProcessor()\n",
    "tokenizer = BertTokenizer.from_pretrained(os.path.join(args.bert_model, 'vocab.txt'))\n",
    "logger.info('Loading train examples...')\n",
    "train_examples = processor.get_train_examples('data/processed_data')\n",
    "num_train_optimization_steps = int(len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs\n",
    "logger.info('Converting train examples to features...')\n",
    "features = convert_examples_to_features(train_examples, args.max_src_len, args.max_tgt_len, tokenizer)\n",
    "example = train_examples[0]\n",
    "example_feature = features[0]\n",
    "logger.info(\"*** Example ***\")\n",
    "logger.info(\"guid: %s\" % (example.guid))\n",
    "logger.info(\"src text: %s\" % example.src)\n",
    "logger.info(\"src_ids: %s\" % \" \".join([str(x) for x in example_feature.src_ids]))\n",
    "logger.info(\"src_mask: %s\" % \" \".join([str(x) for x in example_feature.src_mask]))\n",
    "logger.info(\"tgt text: %s\" % example.tgt)\n",
    "logger.info(\"tgt_ids: %s\" % \" \".join([str(x) for x in example_feature.tgt_ids]))\n",
    "logger.info(\"tgt_mask: %s\" % \" \".join([str(x) for x in example_feature.tgt_mask]))\n",
    "logger.info('Building dataloader...')\n",
    "all_src_ids = torch.tensor([f.src_ids for f in features], dtype=torch.long)\n",
    "all_src_mask = torch.tensor([f.src_mask for f in features], dtype=torch.long)\n",
    "all_tgt_ids = torch.tensor([f.tgt_ids for f in features], dtype=torch.long)\n",
    "all_tgt_mask = torch.tensor([f.tgt_mask for f in features], dtype=torch.long)\n",
    "train_data = TensorDataset(all_src_ids, all_src_mask, all_tgt_ids, all_tgt_mask)\n",
    "train_sampler = RandomSampler(train_data)\n",
    "train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "06/20/2019 20:45:26 - INFO - pytorch_pretrained_bert.modeling -   loading archive file pretrained_model\n",
      "06/20/2019 20:45:26 - INFO - pytorch_pretrained_bert.modeling -   Model config {\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"directionality\": \"bidi\",\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"pooler_fc_size\": 768,\n",
      "  \"pooler_num_attention_heads\": 12,\n",
      "  \"pooler_num_fc_layers\": 3,\n",
      "  \"pooler_size_per_head\": 128,\n",
      "  \"pooler_type\": \"first_token_transform\",\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"vocab_size\": 21128\n",
      "}\n",
      "\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "CUDA error: out of memory",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-13-62bedd2cb8e7>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;31m# model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mBertAbsSum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbert_model\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdecoder_config\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      6\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mn_gpu\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m     \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataParallel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/sum/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mto\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    379\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    380\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 381\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    382\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    383\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mregister_backward_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/sum/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m    185\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    186\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 187\u001b[0;31m             \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    188\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    189\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mparam\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_parameters\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/sum/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m    185\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    186\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 187\u001b[0;31m             \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    188\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    189\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mparam\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_parameters\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/sum/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m    185\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    186\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 187\u001b[0;31m             \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    188\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    189\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mparam\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_parameters\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/sum/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m    191\u001b[0m                 \u001b[0;31m# Tensors stored in modules are graph leaves, and we don't\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    192\u001b[0m                 \u001b[0;31m# want to create copy nodes, so we have to unpack the data.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 193\u001b[0;31m                 \u001b[0mparam\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    194\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0mparam\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_grad\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    195\u001b[0m                     \u001b[0mparam\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_grad\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_grad\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/sum/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mconvert\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m    377\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    378\u001b[0m         \u001b[0;32mdef\u001b[0m \u001b[0mconvert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 379\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    380\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    381\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: CUDA error: out of memory"
     ]
    }
   ],
   "source": [
    "from utils import count_parameters\n",
    "\n",
    "# model\n",
    "model = BertAbsSum(args.bert_model, decoder_config, device)\n",
    "model.to(device)\n",
    "if n_gpu > 1:\n",
    "    model = torch.nn.DataParallel(model)\n",
    "logger.info(f'Total parameters:{count_parameters(model)}')\n",
    "\n",
    "# optimizer\n",
    "param_optimizer = list(model.named_parameters())\n",
    "no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\n",
    "optimizer_grouped_parameters = [\n",
    "    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},\n",
    "    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]\n",
    "optimizer = BertAdam(optimizer_grouped_parameters,\n",
    "                     lr=args.learning_rate,\n",
    "                     warmup=0.1,\n",
    "                     t_total=num_train_optimization_steps)\n",
    "\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pdb\n",
    "\n",
    "logger.info(\"***** Running training *****\")\n",
    "logger.info(\"  Num examples = %d\", len(train_examples))\n",
    "logger.info(\"  Batch size = %d\", args.train_batch_size)\n",
    "logger.info(\"  Num steps = %d\", num_train_optimization_steps)\n",
    "model.train()\n",
    "global_step = 0\n",
    "for i in range(int(args.num_train_epochs)):\n",
    "    tr_loss = 0\n",
    "    nb_tr_examples, nb_tr_steps = 0, 0\n",
    "    for step, batch in enumerate(tqdm(train_dataloader, desc=\"Iteration\")):\n",
    "        batch = tuple(t.to(device) for t in batch)\n",
    "        draft_logits, refine_logits = model(*batch)\n",
    "        loss = cal_loss(draft_logits, refine_logits, batch[2])\n",
    "        if n_gpu > 1:\n",
    "            loss = loss.mean()\n",
    "        if args.gradient_accumulation_steps > 1:\n",
    "            loss = loss / args.gradient_accumulation_steps\n",
    "        pdb.set_trace()\n",
    "        loss.backward()\n",
    "        tr_loss += loss.item()\n",
    "        nb_tr_examples += batch[0].size(0)\n",
    "        nb_tr_steps += 1\n",
    "        if (step + 1) % args.gradient_accumulation_steps == 0:\n",
    "            optimizer.step()\n",
    "            optimizer.zero_grad()\n",
    "            global_step += 1\n",
    "        if step % args.print_every == 0:\n",
    "            logger.info(f'Epoch {i}, step {step}, loss {loss.item()}.')\n",
    "    torch.save(model.state_dict(), os.join(model_path, 'BertAbsSum.bin'))\n",
    "    logger.info(f'Epoch {i} finished. Model saved.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
